import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from base import Memory, env0, env1
import numpy as np
import time

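# A small MLP that maps a state vector to one Q-value per action.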
class DQN(nn.Module):
    def __init__(self,n_state,n_action):
        super(DQN, self).__init__()
        self.seq = nn.Sequential(
            nn.Linear(n_state,20),
            nn.ReLU(),
            nn.Linear(20,n_action)
        )

        # parameters() yields tensors, not layers, so iterate over the submodules
        # when looking for the Linear layers to initialize
        for m in self.seq.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)

    def forward(self,state):
        return self.seq(state)

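# DQN agent: an online (eval) network is trained against a periodically synced
# target network; the replay buffer, summary writer and device come from the
# Memory base class.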
class Model(Memory):
    def __init__(self,n_state,n_action,
                 lr=1e-3,epsilon=0.5,epsilon_decay=0.9996,
                 capacity=10000,logs='./logs'):
        super(Model, self).__init__(capacity,logs)
        self.n_state = n_state
        self.n_action = n_action
        self.eval_net = DQN(n_state,n_action).to(self.device)
        self.target_net = DQN(n_state,n_action).to(self.device)

        self.opt = torch.optim.Adam(self.eval_net.parameters(),lr=lr)
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay

    def optimize(self,batch):
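        # Unpack the sampled minibatch of transitions and move it to the training device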
        s = torch.tensor(batch.s,dtype=torch.float).to(self.device)
        a = torch.tensor(batch.a,dtype=torch.long).to(self.device)
        r = torch.tensor(batch.r,dtype=torch.float).to(self.device)
        s_ = torch.tensor(batch.s_,dtype=torch.float).to(self.device)
        done = torch.tensor(batch.done,dtype=torch.float).to(self.device)

        # Q(s, a) for the actions actually taken
        q_predicted = self.eval_net(s).gather(dim=1, index=a)
        # TD target: r + max_a' Q_target(s', a'); terminal states stop the bootstrap,
        # and no_grad keeps gradients from flowing into the target network
        with torch.no_grad():
            q_next, _ = torch.max(self.target_net(s_), dim=1, keepdim=True)
            q_expected = r + (1 - done) * q_next
        loss = F.mse_loss(q_predicted, q_expected)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        self.writer.add_scalar('loss',loss.item(),self.step)
        return

    def choose_action(self, state):
        # Epsilon-greedy: explore with probability epsilon, which decays each call
        self.epsilon *= self.epsilon_decay
        if np.random.uniform(0, 1) < self.epsilon:
            action = np.random.randint(0, self.n_action)
        else:
            state = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(self.device)
            with torch.no_grad():
                action = torch.argmax(self.eval_net(state), dim=1).item()
        self.writer.add_scalar('action', action, self.step)
        return action

    def hard_update(self):
        # Copy the online network's weights into the target network
        self.target_net.load_state_dict(self.eval_net.state_dict())

    def save_model(self, ckpt_dir='./logs'):
        # Save the state_dict rather than the parameters() generator
        torch.save(self.eval_net.state_dict(), os.path.join(ckpt_dir, 'model.pth'))

    def load_model(self, ckpt_dir='./logs'):
        ckpt_path = os.path.join(ckpt_dir, 'model.pth')
        if os.path.exists(ckpt_path):
            self.eval_net.load_state_dict(torch.load(ckpt_path, map_location=self.device))
        self.hard_update()

if __name__=='__main__':
    epoch = 1000
    max_step_per_epoch = 1000
    update_epoch = 10
    batch_size = 64
    env = env0
    env.seed(int(time.time()))

    model = Model(env.observation_space.shape[0],env.action_space.n)
    model.load_model()

    print("Sampling...")
    count = 0
    state = env.reset()
    while count < model.capacity:
        action = model.choose_action(state)
        state_, reward, done, info = env.step(action)
        model.put_transition(state, [action], [reward], state_, [done])
        state = state_
        count += 1
        if done:
            state = env.reset()

    count = model.capacity
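    # Training loop: one optimization step per environment step, with a hard
    # target-network update every `update_epoch` episodes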
    for ep in range(epoch):
        state = env.reset()
        for st in range(max_step_per_epoch):
            action = model.choose_action(state)
            state_,reward,done,info = env.step(action)
            model.put_transition(state,[action],[reward],state_,[done])
            state = state_

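            # Sample a minibatch from the replay buffer and take one gradient step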
            batch = model.get_transition(batch_size=batch_size)
            model.optimize(batch)
            env.render()
            print('Epoch:[{}/{}],step:[{}/{}]'.format(ep + 1, epoch, st + 1, max_step_per_epoch))

            count += 1
            if done:
                model.writer.add_scalar('ep_r',st+1,count)
                break
        if (ep+1) % update_epoch == 0:
            model.hard_update()
